# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import argparse
import os

import sys
import time

import pytest
import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from tensordict import TensorDict
from torchrl._utils import logger as torchrl_logger
from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.data.replay_buffers.writers import RoundRobinWriter

# How many times _construct_buffer tries rpc.remote before raising.
RETRY_COUNT: int = 3
# Seconds _construct_buffer waits (time.sleep) after a failed attempt.
RETRY_BACKOFF: int = 3


class ReplayBufferNode(RemoteTensorDictReplayBuffer):
    """Remote replay buffer instantiated on the RPC worker that owns it.

    Wires together a ``LazyMemmapStorage`` placed on CPU, a ``RandomSampler``,
    a ``RoundRobinWriter`` and an identity ``collate_fn`` (samples are returned
    exactly as the storage yields them).

    Args:
        capacity: forwarded to the storage as ``max_size``.
        scratch_dir: optional directory forwarded to ``LazyMemmapStorage``.
    """

    def __init__(self, capacity: int, scratch_dir=None):
        cpu = torch.device("cpu")
        storage = LazyMemmapStorage(
            max_size=capacity,
            scratch_dir=scratch_dir,
            device=cpu,
        )
        sampler = RandomSampler()
        writer = RoundRobinWriter()
        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=writer,
            collate_fn=lambda batch: batch,
        )


def construct_buffer_test(rank, name, world_size):
    """On the TRAINER worker, create a buffer on BUFFER and check the handle type.

    Other workers do nothing. ``rank`` and ``world_size`` are accepted only so
    every test function shares the same ``(rank, name, world_size)`` signature.
    """
    if name != "TRAINER":
        return
    rref = _construct_buffer("BUFFER")
    assert type(rref) is torch._C._distributed_rpc.PyRRef


def add_to_buffer_remotely_test(rank, name, world_size):
    """On the TRAINER worker, add one entry remotely and check the returned index.

    The first ``add`` into a freshly constructed buffer is expected to return
    the plain ``int`` 0. Non-TRAINER workers return immediately.
    """
    if name != "TRAINER":
        return
    rref = _construct_buffer("BUFFER")
    index, _ = _add_random_tensor_dict_to_buffer(rref)
    assert type(index) is int
    assert index == 0


def sample_from_buffer_remotely_returns_correct_tensordict_test(rank, name, world_size):
    """On the TRAINER worker, check that sampling returns the entry just added.

    The buffer is freshly constructed and holds exactly one entry, so a sample
    of size 1 must contain that entry's ``"a"`` value. Non-TRAINER workers
    return immediately.
    """
    if name != "TRAINER":
        return
    rref = _construct_buffer("BUFFER")
    _, added = _add_random_tensor_dict_to_buffer(rref)
    drawn = _sample_from_buffer(rref, 1)
    assert type(drawn) is type(added) is TensorDict
    assert (drawn["a"] == added["a"]).all()


@pytest.mark.skipif(
    sys.platform == "win32",
    reason="Distributed package support on Windows is a prototype feature and is subject to changes.",
)
@pytest.mark.parametrize("names", [["BUFFER", "TRAINER"]])
@pytest.mark.parametrize(
    "func",
    [
        construct_buffer_test,
        add_to_buffer_remotely_test,
        sample_from_buffer_remotely_returns_correct_tensordict_test,
    ],
)
def test_funcs(names, func):
    """Run ``func`` in a process pool with one worker per RPC role.

    Each worker joins the RPC group via ``init_rpc`` and then runs
    ``func(rank, name, world_size)``. Finally every worker shuts its RPC
    agent down.

    Args:
        names: RPC worker names; list position is used as the rank.
        func: test body taking ``(rank, name, world_size)``.
    """
    world_size = len(names)
    worker_args = [(rank, name, world_size) for rank, name in enumerate(names)]
    with mp.Pool(world_size) as pool:
        pool.starmap(init_rpc, worker_args)
        pool.starmap(func, worker_args)
        # Previously this was a single un-awaited ``apply_async(shutdown)``.
        # That reached only one worker, and the pool's ``__exit__`` (which
        # calls ``terminate()``) could kill it before it ran. A graceful
        # ``rpc.shutdown()`` waits for all workers, so every worker must call
        # it, and we block until they all have (errors are re-raised here).
        pool.starmap(shutdown, [()] * world_size)


def init_rpc(rank, name, world_size):
    """Register this process as RPC worker ``name`` with the given ``rank``.

    Uses the TensorPipe backend with 16 worker threads and rendezvous at
    ``tcp://localhost:10030``. MASTER_ADDR/MASTER_PORT are also exported.
    NOTE(review): with an explicit tcp:// init_method these env vars are
    presumably unused; kept for parity.
    """
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="29500")
    backend_options = rpc.TensorPipeRpcBackendOptions(
        init_method="tcp://localhost:10030",
        num_worker_threads=16,
    )
    rpc.init_rpc(
        name,
        backend=rpc.BackendType.TENSORPIPE,
        rank=rank,
        world_size=world_size,
        rpc_backend_options=backend_options,
    )


def shutdown():
    """Shut down this process's RPC agent; a module-level name so pools can pickle it."""
    return rpc.shutdown()


def _construct_buffer(target, capacity: int = 1000):
    """Create a ``ReplayBufferNode`` on RPC worker ``target`` and return its RRef.

    Args:
        target: name of the RPC worker that will own the buffer.
        capacity: forwarded to ``ReplayBufferNode`` (its storage ``max_size``).

    Returns:
        The RRef returned by ``rpc.remote``.

    Raises:
        RuntimeError: if all ``RETRY_COUNT`` attempts fail. The last
            underlying error is chained as ``__cause__``.
    """
    last_error = None
    for attempt in range(RETRY_COUNT):
        try:
            return rpc.remote(target, ReplayBufferNode, args=(capacity,))
        except Exception as e:
            last_error = e
            torchrl_logger.info(f"Failed to connect: {e}")
            # Back off only if another attempt follows; sleeping after the
            # final failure just delays the error.
            if attempt + 1 < RETRY_COUNT:
                time.sleep(RETRY_BACKOFF)
    raise RuntimeError("Unable to connect to replay buffer") from last_error


def _add_random_tensor_dict_to_buffer(buffer):
    """Add a random single-entry TensorDict to the remote buffer.

    The TensorDict has key ``"a"`` holding ``torch.randint(100, (1,))`` and an
    empty batch size. ``ReplayBufferNode.add`` is invoked synchronously on the
    RRef's owner.

    Returns:
        ``(add_result, tensordict)``: the value returned by the remote ``add``
        and the TensorDict that was sent.
    """
    payload = TensorDict({"a": torch.randint(100, (1,))}, [])
    owner = buffer.owner()
    add_result = rpc.rpc_sync(owner, ReplayBufferNode.add, args=(buffer, payload))
    return add_result, payload


def _sample_from_buffer(buffer, batch_size):
    """Synchronously call ``ReplayBufferNode.sample(batch_size)`` on the RRef's owner."""
    owner = buffer.owner()
    call_args = (buffer, batch_size)
    return rpc.rpc_sync(owner, ReplayBufferNode.sample, args=call_args)


if __name__ == "__main__":
    # Pass any command-line flags argparse does not recognise straight to pytest.
    _, pytest_extra_args = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *pytest_extra_args])
